#!/usr/bin/python

# Copyright (c) 2010 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

# This is a fully automated VM test of autoupdate. Here's what it does:
# - Downloads the latest dev channel image (DownloadLatestImage is currently
#   pinned to a fixed test image instead)
# - Creates a VMware image based on that image
# - Creates an update for the image under test
# - Creates a copy of that dev channel image that's one number higher than
#   the image under test (called the rollback image)
# - Creates an update for the rollback image
# - Launches a local HTTP server to pretend to be the AU server
# - Fires up the VM and waits for it to contact the AU server
# - AU server gives the VM the image under test update
# - Waits for the image to be installed, then reboots the VM
# - AU server gives the VM the rollback image
# - Waits for the image to be installed, then reboots the VM
# - Waits for the image to ping the AU server with the rollback image
# - Done!

# Run this program by passing it a path to the build directory you want
# to test. That directory must contain chromiumos_image.bin and
# unpack_partitions.sh; rootfs.image is extracted from them into the same
# directory.

from xml.dom import minidom
from BaseHTTPServer import BaseHTTPRequestHandler
from BaseHTTPServer import HTTPServer

import os
import re
import signal
import socket
import string
import subprocess
import sys
import tempfile
import threading
import time

# Scratch directory for downloaded and unpacked images.
tmp_dir = '/tmp/au_vm_test'
# Chromium OS scripts directory; a relative path, so it is resolved against
# the current working directory when the scripts are invoked.
scripts_dir = '../../scripts/'

# Placeholder version strings (not read by the functions in this script).
original_version = test_version = '0.0.0.0'

class TestState(object):
  """Test state machine; glue between the AU server and the VMware process.

  The AU server calls HandleUpdateRequest() for every update check and
  ImageDownloadComplete() after it has streamed an update payload.
  threading.Timer callbacks either fail the test (StartupTimeout, when the
  VM never pings) or reboot the VM after an install (FinishInstallTimeout).

  States advance in this order:
    INIT -> INITIAL_WAIT        Start(): VM launched, waiting for first ping
         -> TEST_DOWNLOAD       orig_vers pinged; it is offered test_vers
         -> TEST_WAIT           payload served; reboot timer running
         -> ROLLBACK_DOWNLOAD   test_vers pinged; it is offered rollback_vers
         -> ROLLBACK_WAIT       payload served; a ping from rollback_vers
                                ends the test successfully.
  """
  # States we can be in, in the order the test moves through them.
  (INIT,
   INITIAL_WAIT,
   TEST_DOWNLOAD,
   TEST_WAIT,
   ROLLBACK_DOWNLOAD,
   ROLLBACK_WAIT) = range(6)

  def __init__(self, orig_vers, test_vers, rollback_vers, vm):
    """Create the state machine in state INIT.

    Args:
      orig_vers: version string the VM boots with initially.
      test_vers: version string of the image under test.
      rollback_vers: version string of the rollback image.
      vm: object providing Launch(), Shutdown() and Destroy() (e.g. vmplayer).
    """
    self.reboot_timeout_seconds = 60 * 4
    self.orig_vers = orig_vers
    self.test_vers = test_vers
    self.rollback_vers = rollback_vers
    # Pending threading.Timer, or None. Set before anything else can run so
    # that requests arriving before Start() see None instead of raising
    # AttributeError.
    self.timer = None
    self.SetState(TestState.INIT)
    self.vm = vm

  @staticmethod
  def _ExitProcess(code):
    """Terminate the whole process with exit status `code`.

    Callers run on threading.Timer threads or inside the HTTP request
    handler. There, sys.exit() only raises SystemExit in the calling thread:
    a Timer thread just ends, and Python 2's SocketServer request loop
    catches it with a bare except. os._exit() ends the process from any
    thread. stdout is flushed first because os._exit skips interpreter
    cleanup.
    """
    sys.stdout.flush()
    os._exit(code)

  def Die(self, message):
    """Print `message`, tear down the VM and end the test with status 1."""
    print(message)
    self.vm.Destroy()
    self._ExitProcess(1)

  def SetState(self, state):
    """Move to `state` (one of the class-level state constants)."""
    self.state = state

  def Start(self):
    """Launch the VM and wait for its first update check.

    Must be called once, while in state INIT. If no ping arrives within
    reboot_timeout_seconds, StartupTimeout fails the test.
    """
    if self.state != TestState.INIT:
      self.Die('Start called while in bad state')
    self.SetState(TestState.INITIAL_WAIT)
    self.vm.Launch()
    # Kick off timer to wait for the AU ping
    self.timer = threading.Timer(self.reboot_timeout_seconds,
                                 self.StartupTimeout)
    self.timer.start()

  def StartupTimeout(self):
    """Timer callback: the VM did not ping the AU server in time."""
    self.Die('VM Failed to start and ping ' + str(id(self)))

  def FinishInstallTimeout(self):
    """Timer callback: reboot the VM after an install and await its ping."""
    self.vm.Shutdown()
    self.vm.Launch()
    self.timer = threading.Timer(60 * 5,  # Sometimes VMWare is very slow
                                 self.StartupTimeout)
    self.timer.start()

  def HandleUpdateRequest(self, from_version):
    """Handle an update check from a client running `from_version`.

    Returns:
      The version the AU server should offer the client, or '' when no
      update should be offered.
    """
    print('HandleUpdateRequest(%s) id:%s state:%s' %
          (from_version, id(self), self.state))
    ret = None
    # A ping from the version we are waiting to boot means the (re)boot
    # succeeded, so the pending startup/reboot timer can be cancelled.
    if self.timer is not None and self.state == TestState.INITIAL_WAIT and \
       from_version == self.orig_vers:
      print('Successfully booted initial')
      self.timer.cancel()
      self.timer = None
    elif self.timer is not None and self.state == TestState.TEST_WAIT and \
         from_version == self.test_vers:
      print('Successfully booted test')
      self.timer.cancel()
      self.timer = None
    elif self.timer is not None and self.state == TestState.ROLLBACK_WAIT and \
         from_version == self.rollback_vers:
      print('Successfully booted rollback')
      self.timer.cancel()
      self.timer = None
      print('All done!')
      self._ExitProcess(0)

    # Pick the version to return
    if from_version == self.orig_vers:
      ret = self.test_vers
    elif from_version == self.test_vers:
      ret = self.rollback_vers

    # Checks to make sure we move through states correctly
    if from_version == self.orig_vers:
      if self.state not in (TestState.INITIAL_WAIT, TestState.TEST_DOWNLOAD):
        self.Die('Error: Request from %s while in state %s' %
                 (from_version, self.state))
    elif from_version == self.test_vers:
      if self.state not in (TestState.TEST_WAIT,
                            TestState.ROLLBACK_DOWNLOAD,
                            TestState.ROLLBACK_WAIT):
        self.Die('Error: Request from %s while in state %s' %
                 (from_version, self.state))
    else:
      print('odd version to be pinged from: %s' % from_version)
      print('state is %s' % self.state)

    # Advance only on a ping from the version we were waiting for, so a
    # stray request from an unexpected version can't skip a state.
    if self.state == TestState.INITIAL_WAIT and \
       from_version == self.orig_vers:
      self.SetState(TestState.TEST_DOWNLOAD)
    elif self.state == TestState.TEST_WAIT and \
         from_version == self.test_vers:
      self.SetState(TestState.ROLLBACK_DOWNLOAD)

    if ret is not None:
      return ret
    print('Ignoring update request while in state %s' % self.state)
    return ''

  def ImageDownloadComplete(self):
    """Record that the client finished downloading a payload.

    Moves TEST_DOWNLOAD -> TEST_WAIT or ROLLBACK_DOWNLOAD -> ROLLBACK_WAIT,
    then (re)starts a timer that reboots the VM via FinishInstallTimeout
    after reboot_timeout_seconds. Any other state fails the test.
    """
    print('ImageDownloadComplete()')
    if self.state == TestState.TEST_DOWNLOAD:
      self.SetState(TestState.TEST_WAIT)
    elif self.state == TestState.ROLLBACK_DOWNLOAD:
      self.SetState(TestState.ROLLBACK_WAIT)
    else:
      self.Die('Image download done called at invalid state')
    # Put a timer to reboot the VM
    if self.timer is not None:
      self.timer.cancel()
    self.timer = threading.Timer(self.reboot_timeout_seconds,
                                 self.FinishInstallTimeout)
    self.timer.start()

class AUHTTPServer(HTTPServer):
  """HTTPServer that also holds the update catalogue for AUServerHandler.

  Request handlers reach this object through self.server and read its
  update_info, files and test_state attributes.
  """

  def __init__(self, ip_port, klass):
    HTTPServer.__init__(self, ip_port, klass)
    # version -> (url, hash, size); the tuple order matches the reply
    # template's codebase/hash/size placeholders.
    self.update_info = dict()
    # URL path -> local file path to stream back on GET.
    self.files = dict()

  def AddUpdateInfo(self, version, url, size, the_hash):
    """Register the payload that brings a client *to* `version`.

    Note the stored tuple order, (url, the_hash, size), differs from the
    parameter order.
    """
    entry = (url, the_hash, size)
    self.update_info[version] = entry

  def AddServedFile(self, url_path, file_path):
    """Serve the file at `file_path` for GET requests to `url_path`."""
    self.files.update({url_path: file_path})

  def SetTestState(self, test_state):
    """Attach the TestState object that handlers report progress to."""
    self.test_state = test_state

class AUServerHandler(BaseHTTPRequestHandler):
  """Request handler for the fake AU server.

  POST requests are update checks from the client; GET requests download an
  update payload. Per-test data comes from the owning AUHTTPServer via
  self.server (files, update_info, test_state).
  """

  def do_GET(self):
    """Stream the file registered for self.path, or echo the path back.

    After a registered file has been fully written, the test state is told
    the download completed.
    """
    self.send_response(200)
    self.end_headers()
    print('GET: %s' % self.path)

    # .get() so unregistered paths reach the echo branch instead of
    # raising KeyError.
    file_path = self.server.files.get(self.path)
    if file_path is not None:
      print('GET returning path %s' % file_path)
      # Binary mode: the payload is a gzip file. `with` closes it even if
      # the client disconnects mid-transfer.
      with open(file_path, 'rb') as f:
        while True:
          data = f.read(1024 * 1024 * 8)
          if not data:
            break
          self.wfile.write(data)
          self.wfile.flush()
      self.server.test_state.ImageDownloadComplete()
    else:
      print('GET returning no path')
      self.wfile.write(self.path + '\n')

  def do_POST(self):
    """Answer an update check with the payload info for the next version.

    The client's version is read from the `version` attribute of the first
    <o:app> element in the posted XML.
    """
    # Parse the form data posted
    post_length = int(self.headers.getheader('content-length'))
    post_data = self.rfile.read(post_length)

    update_dom = minidom.parseString(post_data)
    root = update_dom.documentElement
    query = root.getElementsByTagName('o:app')[0]
    client_version = query.getAttribute('version')
    print('Got update request from %s' % client_version)

    # Send response
    self.send_response(200)
    self.end_headers()

    new_version = self.server.test_state.HandleUpdateRequest(client_version)
    print('Appropriate new version is: %s' % new_version)

    # HandleUpdateRequest returns '' when no update should be offered, and
    # that key is never registered, so look it up with .get() rather than [].
    update_info = self.server.update_info.get(new_version)
    if update_info is None:
      print('Not sure how to serve reply for %s' % new_version)
      return

    payload = """<?xml version="1.0" encoding="UTF-8"?>
      <gupdate xmlns="http://www.google.com/update2/response" protocol="2.0">
        <app appid="{87efface-864d-49a5-9bb3-4b050a7c227a}" status="ok">
          <ping status="ok"/>
          <updatecheck
            codebase="%s"
            hash="%s"
            needsadmin="false"
            size="%s"
            status="ok"/>
        </app>
      </gupdate>
    """ % update_info

    self.wfile.write(payload)

class vmplayer(object):
  """Wrapper around a /usr/bin/vmplayer process for one .vmx file."""

  def __init__(self, filename):
    """Args: filename: path to the .vmx file vmplayer should run."""
    self.filename = filename
    # Nothing launched yet. Set here so Destroy() is safe before Launch().
    self.process = None
    self.running = False

  def Launch(self):
    """Start vmplayer on the .vmx file.

    Returns as soon as the process is spawned, which is typically before
    the guest OS has booted.
    """
    self.process = subprocess.Popen(['/usr/bin/vmplayer', self.filename])
    self.running = True

  def Destroy(self):
    """Shut the VM down if it is running; otherwise do nothing."""
    if self.running:
      self.Shutdown()

  def Shutdown(self):
    """Stop vmplayer and remove the saved-state entry from the .vmx file.

    Blocks until the vmplayer process has exited.
    """
    # Pretend user sent Ctrl-C to the vmplayer process
    os.kill(self.process.pid, signal.SIGINT)

    # Wait while vmplayer saves the vm state...
    self.process.wait()

    # The process is gone; record that before the cleanup below, so a
    # failing sed can't leave running=True with no process to signal.
    self.process = None
    self.running = False

    # Delete the saved vm state reference (lines matching checkpoint.vmState)
    # from the .vmx file.
    # TODO(adlr): remove the state file from the disk
    subprocess.check_call(['/bin/sed', '-i', '/checkpoint.vmState/d',
                           self.filename])

def MakePath(path):
  """Create directory `path` and any missing parents, like `mkdir -p`."""
  mkdir_cmd = ['/bin/mkdir', '-p', path]
  subprocess.check_call(mkdir_cmd)

def DownloadLatestImage(out_path, url=None):
  """Download the baseline (pre-update) image zip to out_path with wget.

  Args:
    out_path: destination file for the downloaded zip.
    url: optional URL to fetch. Defaults to the fixed test image below.

  Raises:
    subprocess.CalledProcessError: if wget fails.
  """
  # The official latest dev-channel image lives at
  #   http://codf30.jail.google.com/internal/archive/x86-image-official/
  #   LATEST-dev-channel/image.zip.NOT_SAFE_FOR_USB_INSTALL
  # but the test is currently pinned to the fixed image below.
  if url is None:
    url = 'http://www.corp.google.com/~adlr/adlr_test_orig.zip'
  subprocess.check_call(['/usr/bin/wget', '-O', out_path, url])

def UnzipImage(path, directory):
  """Extract the zip archive at `path` into `directory` using /usr/bin/unzip."""
  unzip_cmd = ['/usr/bin/unzip', path]
  unzip_cmd += ['-d', directory]
  subprocess.check_call(unzip_cmd)

# Create a stateful partition image whose etc/lsb-release points the AU
# client at this script's HTTP server (local_ip:local_port).
def CreateDefaultStatefulPartition(local_ip, local_port, out_dir, out_file):
  """Build an ext4 stateful-partition image at out_file.

  The image contains etc/lsb-release with
  CHROMEOS_AUSERVER=http://<local_ip>:<local_port>/update and
  HTTP_SERVER_OVERRIDE=true, owned by root with 0644/0755 permissions.

  Args:
    local_ip: host the AU client should contact. main() passes
      socket.gethostname(), so a hostname is accepted here, not only an IP.
    local_port: port of the AU server.
    out_dir: scratch directory; the image is loop-mounted on out_dir/state
      while it is being populated.
    out_file: path of the image file to create (formatted with
      mkfs.ext4 -F).

  Raises:
    subprocess.CalledProcessError: if any step of the script fails
      (`set -e`).
  """
  # Create sparse file for partition
  part_size = 512 * 1024 * 1024  # bytes, i.e. 512 MiB

  # NOTE: the script is passed as one string with shell=True, so it runs
  # under /bin/sh -c; the "#!/bin/bash" line is only a comment there.
  # Arguments are interpolated unquoted, so pass only trusted local paths
  # and hostnames. Needs sudo for mount/chown/chmod/umount.
  subprocess.check_call(["""#!/bin/bash
  set -ex
  OUT_FILE="%s"
  STATE_DIR="%s"
  SIZE="%s"
  dd if=/dev/zero of="$OUT_FILE" bs=1 count=1 seek=$(($SIZE - 1))
  mkfs.ext4 -F "$OUT_FILE"
  mkdir -p "$STATE_DIR"
  sudo mount -o loop "$OUT_FILE" "$STATE_DIR"
  sudo mkdir -p "$STATE_DIR/etc"
  cat <<EOF |sudo dd of="$STATE_DIR/etc/lsb-release"
CHROMEOS_AUSERVER=http://%s:%s/update
HTTP_SERVER_OVERRIDE=true
EOF
  for i in "$STATE_DIR/etc/lsb-release" "$STATE_DIR/etc" "$STATE_DIR"; do
    sudo chown root:root "$i"
    if [ -d "$i" ]; then
      sudo chmod 0755 "$i"
    else
      sudo chmod 0644 "$i"
    fi
  done
  sudo umount -d "$STATE_DIR"
  """ % (out_file, out_dir + '/state', part_size, local_ip, local_port)],
    shell=True)
  return
  

def CreateVMForImage(image_dir):
  """Build a VMware image in `image_dir` from the image in that directory.

  Runs image_to_vmware.sh with image_dir as both source and destination
  and image_dir/state.image as the stateful partition image.
  """
  script = scripts_dir + 'image_to_vmware.sh'
  flags = ['--from', image_dir,
           '--to', image_dir,
           '--state_image', image_dir + '/state.image']
  subprocess.check_call([script] + flags)

def CreateUpdateForImage(rootfs_image):
  """Generate an update payload for `rootfs_image` with mk_memento_images.sh.

  The payload is read back from update.gz in the rootfs image's directory.

  Returns:
    (size, hash, path) of the generated update.gz, where hash is the value
    the script prints after 'Success. hash is '.

  Raises:
    subprocess.CalledProcessError: if mk_memento_images.sh exits non-zero.
    RuntimeError: if the script's output contains no hash line.
  """
  cmd = [scripts_dir + 'mk_memento_images.sh', rootfs_image]
  proc = subprocess.Popen(cmd, stdout=subprocess.PIPE)
  output = proc.communicate()[0]
  # Fail loudly on a script error instead of using a stale update.gz.
  if proc.returncode != 0:
    raise subprocess.CalledProcessError(proc.returncode, cmd)
  matches = re.search(r'Success. hash is ([^\n]+)', output)
  if matches is None:
    raise RuntimeError('No update hash found in output of %s:\n%s' %
                       (cmd[0], output))
  the_hash = matches.group(1)
  path = os.path.dirname(rootfs_image) + '/update.gz'
  size = os.path.getsize(path)
  return (size, the_hash, path)

# Modify the rootfs image 'root_img' in place so it reports version
# new_version.
def ModifyImageForRollback(root_img, new_version):
  """Rewrite the release version recorded in root_img's etc/lsb-release.

  Loop-mounts root_img (via sudo) on a fresh `mktemp -d` directory and edits
  etc/lsb-release with sed:
    * the values of GOOGLE_RELEASE= and CHROMEOS_RELEASE_VERSION= become
      new_version;
    * "-ROLLBACK" is appended to the CHROMEOS_RELEASE_DESCRIPTION line.
  Then it unmounts the image. The temporary mount directory is not removed.

  Raises:
    subprocess.CalledProcessError: if any step fails (`set -e`).
  """
  # NOTE: runs under /bin/sh -c (shell=True); "#!/bin/bash" is a comment.
  # The doubled backslashes are Python escapes producing single backslashes
  # in the sed script. new_version is spliced unescaped into the sed
  # replacement, so a '/' or '&' in it would break or alter the edit.
  # NOTE(review): only the GOOGLE_RELEASE= alternative is anchored with '^';
  # CHROMEOS_RELEASE_VERSION= matches anywhere on a line -- confirm intended.
  subprocess.check_call(["""#!/bin/bash
  set -ex
  DIR=$(mktemp -d)
  sudo mount -o loop "%s" $DIR
  # update versions in lsb-release
  sudo sed -i \\
      -e 's/\\(^GOOGLE_RELEASE=\\|CHROMEOS_RELEASE_VERSION=\\).*/\\1%s/' \\
      -e 's/^\\(CHROMEOS_RELEASE_DESCRIPTION=.*\\)/\\1-ROLLBACK/' \\
      "$DIR"/etc/lsb-release
  sudo umount -d $DIR
  """ % (root_img, new_version)], shell=True)

def GetVersionForRootfs(rootfs_image):
  """Return the GOOGLE_RELEASE value from a rootfs image's etc/lsb-release.

  E.g. '0.6.39.201003241739-a1'. The image is loop-mounted read-only via
  sudo on a temporary directory, which is unmounted and removed afterwards.
  An empty string is returned if no GOOGLE_RELEASE= line is found.

  Raises:
    subprocess.CalledProcessError: if mount, umount or rm fails.
  """
  mount_dir = tempfile.mkdtemp()
  subprocess.check_call(['sudo', 'mount', '-o', 'loop,ro',
                         rootfs_image, mount_dir])
  try:
    version = subprocess.Popen(['awk', '-F', '=',
                                '/GOOGLE_RELEASE=/{print $2}',
                                mount_dir + '/etc/lsb-release'],
                               stdout=subprocess.PIPE).communicate()[0].rstrip()
  finally:
    # Always unmount, even if reading failed, so the mount and loop device
    # are not leaked. If umount itself raises, the rm below is skipped, so
    # `rm -rf` never runs on a still-mounted tree.
    subprocess.check_call(['sudo', 'umount', '-d', mount_dir])
  subprocess.check_call(['sudo', 'rm', '-rf', mount_dir])
  return version

def IncrementVersionNumber(version):
  """Return `version` with the leading number of its last component + 1.

  Any non-numeric suffix on the last component is kept, so release versions
  of the form GetVersionForRootfs returns are supported:
    IncrementVersionNumber('0.23.144.842') == '0.23.144.843'
    IncrementVersionNumber('0.6.39.201003241739-a1')
        == '0.6.39.201003241740-a1'

  Raises:
    ValueError: if the last dot-separated component does not start with a
      digit.
  """
  parts = version.split('.')
  last = parts[-1]
  match = re.match(r'\d+', last)
  if match is None:
    raise ValueError('Cannot increment version %r: last component %r has '
                     'no leading number' % (version, last))
  parts[-1] = str(int(match.group(0)) + 1) + last[match.end():]
  return '.'.join(parts)

def UnpackRootfs(directory):
  """Extract directory/chromiumos_image.bin's rootfs to directory/rootfs.image.

  Runs ./unpack_partitions.sh inside `directory` and renames the resulting
  part_3 to rootfs.image, overwriting any existing rootfs.image.

  Raises:
    OSError: if `directory` or the script is missing, or part_3 was not
      produced.
    subprocess.CalledProcessError: if unpack_partitions.sh fails.
  """
  # cwd= instead of `cd` inside a shell string: without `set -e` a shell
  # string only reports its last command's status, so a failed cd or unpack
  # could go unnoticed and run later steps in the wrong place.
  subprocess.check_call(['./unpack_partitions.sh', 'chromiumos_image.bin'],
                        cwd=directory)
  os.rename(os.path.join(directory, 'part_3'),
            os.path.join(directory, 'rootfs.image'))

def main():
  """Run the end-to-end autoupdate VM test described in the file header.

  Expects exactly one command-line argument: the build directory of the
  image under test.
  """
  args = sys.argv[1:]
  if len(args) != 1:
    print('usage: %s path/to/new/image/dir' % sys.argv[0])
    sys.exit(1)
  new_dir = args[0]
  orig_dir = tmp_dir + '/orig'
  rollback_dir = tmp_dir + '/rollback'
  state_image = orig_dir + '/state.image'
  port = 8080

  MakePath(tmp_dir)

  # Baseline image the VM boots first.
  orig_zip = tmp_dir + '/orig.zip'
  DownloadLatestImage(orig_zip)
  UnzipImage(orig_zip, orig_dir)
  UnpackRootfs(orig_dir)
  orig_version = GetVersionForRootfs(orig_dir + '/rootfs.image')
  print('Have original image at version: %s' % orig_version)

  # Update payload for the image under test.
  print('Creating update.gz for test image')
  new_rootfs = new_dir + '/rootfs.image'
  UnpackRootfs(new_dir)
  new_size, new_hash, new_payload = CreateUpdateForImage(new_rootfs)
  new_version = GetVersionForRootfs(new_rootfs)
  print('Have test image at version: %s' % new_version)

  # Rollback image: the baseline relabelled one version above the image
  # under test.
  rollback_version = IncrementVersionNumber(new_version)
  print('Creating rollback image')
  rollback_rootfs = rollback_dir + '/rootfs.image'
  UnzipImage(orig_zip, rollback_dir)
  UnpackRootfs(rollback_dir)
  ModifyImageForRollback(rollback_rootfs, rollback_version)
  print('Creating update.gz for rollback image')
  rollback_size, rollback_hash, rollback_payload = \
      CreateUpdateForImage(rollback_rootfs)
  print('Have rollback image at version: %s' % rollback_version)

  hostname = socket.gethostname()
  CreateDefaultStatefulPartition(hostname, port, orig_dir, state_image)
  CreateVMForImage(orig_dir)

  player = vmplayer(orig_dir + '/chromeos.vmx')
  test_state = TestState(orig_version, new_version, rollback_version, player)
  server = AUHTTPServer((hostname, port), AUServerHandler)
  base_url = 'http://%s:%s' % (hostname, port)

  server.SetTestState(test_state)
  server.AddUpdateInfo(new_version, base_url + '/' + new_version,
                       new_size, new_hash)
  server.AddUpdateInfo(rollback_version, base_url + '/' + rollback_version,
                       rollback_size, rollback_hash)
  server.AddServedFile('/' + new_version, new_payload)
  server.AddServedFile('/' + rollback_version, rollback_payload)

  test_state.Start()
  print('Starting server, use <Ctrl-C> to stop')
  server.serve_forever()

# Run the test only when executed as a script, not when imported.
if __name__ == '__main__':
  main()
